
#include "src/configuration.hpp"
#include "src/model_set.hpp"
#include "src/optimizer.hpp"
void Optimizer::models(std::unordered_set< Model > & results) {
    if (Configuration::model_limit == 0) { return; }
    std::unordered_set< std::shared_ptr<Model>, std::hash< std::shared_ptr<Model> >, std::equal_to< std::shared_ptr<Model> > > local_results;
    results_t new_results = results_t();
    if (rashomon_flag) {
        string task_description;

                auto extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time
        task_description = "rash_models Duration: ";
        rash_models(this -> root, new_results, rashomon_bound);
                auto extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                if (Configuration::verbose) {
                    float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                    std::cout << task_description << time << " seconds" << std::endl;
                }

        if (Configuration::objective_model_set) {
            extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time

            task_description = "Output of objective Rashomon Set in Model Set: ";

            std::string serialization;
            ModelSet::serialize(new_results, serialization, 0);
            
            std::string file_name = "model_set-objective-" + Configuration::rashomon_model_set;

            if(Configuration::verbose) { std::cout << "Storing Models in: " << file_name << std::endl; }
            std::ofstream out(file_name);
            out << serialization;
            out.close();

                    extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                    if (Configuration::verbose) {
                        float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                        std::cout << task_description << time << " seconds" << std::endl;
                    }

        }
        if (Configuration::covered_sets.size() != 0) {

                    extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time


            task_description = "Construction of (TP, TN, #Leaves) Duration: ";
            for (auto obj : new_results.first) {
                new_results.second[obj]->get_values_of_interest_count();
            }

                    extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                    if (Configuration::verbose) {
                        float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                        std::cout << task_description << time << " seconds" << std::endl;
                    }

            unsigned int P, N;
            State::dataset.get_total_P_N(P, N);

            for (int i = 0; i < Configuration::covered_sets.size(); i++) {

                // Common data for this task
                CoveredSetExtraction covered_sets_type = Configuration::covered_sets[i];
                std::string covered_sets_type_string = Configuration::covered_set_type_to_string(covered_sets_type);
                double limit = Configuration::covered_sets_thresholds[i];

                        extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time

                task_description = "Extraction of " + covered_sets_type_string + " Rashomon Set: ";


                values_of_interest_mapping_t mapping_all;

                unsigned long long model_count = 0;

                for (auto obj : new_results.first) {
                    auto output = new_results.second[obj]->get_values_of_interest_count();
                    bool extract = false;
                    for (auto i : output) {
                        auto values_of_interest = i.first;

                        double TP = values_of_interest.TP;
                        double TN = values_of_interest.TN;
                        auto reg = values_of_interest.regularization;

                        double metric = Configuration::computeScore(covered_sets_type, P, N, TP, TN);
                        double obj_value = 1 - metric + Configuration::regularization * reg;

                        if (obj_value <= limit) {
                            model_count += i.second;
                            extract = true;
                        }
                    }
                    // Here as I believe `get_values_of_interest_mapping` is a more costly process
                    if (extract) {
                        auto mapping = new_results.second[obj]->get_values_of_interest_mapping();
                        for (auto i : mapping) {
                            auto values_of_interest = i.first;

                            double TP = values_of_interest.TP;
                            double TN = values_of_interest.TN;
                            auto reg = values_of_interest.regularization;

                         
                            double metric = Configuration::computeScore(covered_sets_type, P, N, TP, TN);
                            double obj_value = 1 - metric + Configuration::regularization * reg;
                            
                            // assert(output.at(values_of_interest) == i.second->get_stored_model_count());

                            if (obj_value <= limit) {
                                auto existing_model_set = mapping_all.find(values_of_interest);
                                if (existing_model_set == mapping_all.end()) {
                                    mapping_all.insert(i);
                                } else {
                                    existing_model_set->second->merge(i.second);
                                }
                            }
                        }
                    }
                }
                std::cout << "Size of " + covered_sets_type_string + " Rashomon Set: " << model_count << std::endl;

                        extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                        if (Configuration::verbose) {
                            float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                            std::cout << task_description << time << " seconds" << std::endl;
                        }
                        extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time

                task_description = "Output of " + covered_sets_type_string + " Rashomon Set in Model Set: ";

                std::string serialization;
                ModelSet::serialize(mapping_all, serialization, 0);
                
                std::string file_name = "model_set-" + covered_sets_type_string + "-" + Configuration::rashomon_model_set;

                if(Configuration::verbose) { std::cout << "Storing Models in: " << file_name << std::endl; }
                std::ofstream out(file_name);
                out << serialization;
                out.close();

                        extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                        if (Configuration::verbose) {
                            float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                            std::cout << task_description << time << " seconds" << std::endl;
                        }
                        extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time

            }
        }
        
        if (Configuration::rashomon_trie != "") {

                extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time

        task_description = "Insersion of Rashomon Set into Trie: ";

        // bool calculate_size = false;
        // char const *type = "node";
        // Trie* tree = new Trie(calculate_size, type);
        // tree->insert_root();

        // for (auto obj : new_results.first) {
        //     tree->insert_model_set(new_results.second[obj]);
        // }

        // for (auto obj : new_results.first) {
        //     for (auto i : new_results.second[obj]->get_values_of_interest_mapping()) {
        //         auto values_of_interest = i.first;

        //         float TP = std::get<0>(values_of_interest);
        //         float TN = std::get<1>(values_of_interest);
        //         auto reg = std::get<2>(values_of_interest);

        //      
        //         // float bacc = (TP / P + TN / N) / 2;
        //         // float obj_value = 1 - bacc + Configuration::regularization * reg;

        //         float precision = TP / P;
        //         float recall = TP / (TP + N - TN);
        //         float f1 = 2 * precision * recall / (precision + recall);
        //         float obj_value = 1 - f1 + Configuration::regularization * reg;

        //         // if (obj_value < 0.2890741701244814) {
        //         // if (obj_value < 0.3534467120181406) {
        //         if (obj_value < bacc_limit) {
        //             tree->insert_model_set(i.second);
        //         }
        //     }
        // }

                extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                if (Configuration::verbose) {
                    float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                    std::cout << task_description << time << " seconds" << std::endl;
                }
        }
                extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time

        task_description = "Output of Rashomon Set in Trie: ";

        // std::string serialization_1;
        // tree->serialize(serialization_1, 0);
        // // std::cout << serialization_1 << std::endl;

        // if(Configuration::verbose) { std::cout << "Storing Models in: " << Configuration::rashomon_trie + "-new" << std::endl; }
        // std::ofstream out_1(Configuration::rashomon_trie + "-new");
        // out_1 << serialization_1;
        // out_1.close();


                extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                if (Configuration::verbose) {
                    float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                    std::cout << task_description << time << " seconds" << std::endl;
                }

        task_description = "Past Extraction Method: ";

                extraction_start = std::chrono::high_resolution_clock::now(); // Start measuring training time
        // models(this -> root, local_results, rashomon_bound);
                extraction_stop = std::chrono::high_resolution_clock::now(); // Stop measuring training time
                if (Configuration::verbose) {
                    float time = std::chrono::duration_cast<std::chrono::milliseconds>(extraction_stop - extraction_start).count() / 1000.0;
                    std::cout << task_description << time << " seconds" << std::endl;
                }

    } else {
        models(this -> root, local_results);
    }
    // std::cout << "Local Size: " << local_results.size() << std::endl;
    // std::cout << "Result Size: " void Optimizer::models(std::unordered_set< Model > & results) {<< results.size() << std::endl;
    
    std::cout << "Memory usage: " << getCurrentRSS() / 1000000 << std::endl;
    std::cout << "Cached subproblem models size: " << State::graph.models.size() << std::endl;
    std::cout << "Models calls: " << models_calls << std::endl;
    std::cout << "Pruned combinations using scope: " << pruned_combinations_with_scope << std::endl;
    std::cout << "Pruned leaves using scope: " << pruned_leaves_with_scope << std::endl;
    std::cout << "Pruned trivial extensions: " << pruned_trivial_extension << std::endl;
    std::cout << "Max results size: " << max_result_size << std::endl;
    std::cout << "Re-explore by scope count: " << re_explore_by_scope_update_count << std::endl;
    std::cout << "Re-explore count: " << re_explore_count << std::endl;


    if (rashomon_flag) {   
        std::cout << "Stored keys size: " << new_results.first.size() << std::endl;
        boost::multiprecision::uint128_t models_count = 0;
        for (auto model_set : new_results.second) {
            models_count += model_set.second->get_stored_model_count();
        }
        std::cout << "Size of Rashomon Set: " << models_count << std::endl;
    }
    // Copy into final results
    if (model_limit_exceeded) {
        std::cout << "Model limit exceeded. Will not produce any model." << std::endl;
        results.clear();
        return;
    }
    int count = 0;
    for (auto iterator = local_results.begin(); iterator != local_results.end(); ++iterator) {

        // std::pair< std::unordered_set<Model>::iterator, bool > insertion = results.insert(Model());
        // * (insertion.first) = (** iterator);
         Model * model = new Model(**iterator);
         count++;
        
        if (rashomon_flag) {    
            if (model->loss() + model->complexity() <= rashomon_bound + std::numeric_limits<float>::epsilon()) {
                // std::string serialization;
                // (**iterator).serialize(serialization, 2);
                // std::cout << serialization << std::endl;
                // std::cout << count << std::endl;
                results.insert(**iterator);
            }
            //results.insert(**iterator);

        } else {
            std::string serialization;
            (**iterator).serialize(serialization, 2);
            std::cout << serialization << std::endl;
            results.insert(**iterator);
        }
        delete model;
    }
    //std::cout << "Local Size: " << local_results.size() << std::endl;
    //std::cout << "Result Size: " << results.size() << std::endl;

    //std::cout << "Local Size: " << local_results.size() << std::endl;
    //std::cout << "Result Size: " << results.size() << std::endl;
}

void Optimizer::models(key_type const & identifier, std::unordered_set< std::shared_ptr<Model>, std::hash< std::shared_ptr<Model> >, std::equal_to< std::shared_ptr<Model> > > & results, float scope) {
    // Guarded entry point for recursive extraction. Once model_limit_exceeded has been
    // raised, no further subproblems are explored: the top-level extraction clears its
    // output in that case, so any further work would be discarded.
    if (!model_limit_exceeded) {
        models_inner(identifier, results, scope);
    }
}

// Recursively enumerates the models of the subproblem `identifier` and inserts them
// into `results`.
//
// Rashomon mode (rashomon_flag): keeps every stump, leaf, and child combination whose
// objective is within `scope` (with a float-epsilon tolerance). A child is explored
// with `scope` reduced by its sibling's lower bound. The task's maximum_scope is raised
// to `scope`, and the re-exploration counters are updated.
// Otherwise: keeps candidates whose bound does not exceed the task's upper bound.
//
// Returns early (without updating max_result_size) when the vertex or its bounds are
// missing, or when results grows beyond Configuration::model_limit (this also sets
// model_limit_exceeded). Throws std::invalid_argument when Configuration::rule_list is
// set and a split has models on both sides.
void Optimizer::models_inner(key_type const & identifier, std::unordered_set< std::shared_ptr<Model>, std::hash< std::shared_ptr<Model> >, std::equal_to< std::shared_ptr<Model> > > & results, float scope) {
    // Absolute slack applied uniformly to every objective-vs-limit comparison below.
    float const tolerance = std::numeric_limits<float>::epsilon();

    vertex_accessor task_accessor;
    if (State::graph.vertices.find(task_accessor, identifier) == false) { return; }
    Task & task = task_accessor -> second;

    // Stump: the task's whole capture set predicted by a single leaf.
    auto insert_stump = [&]() {
        std::shared_ptr<Model> model(new Model(std::shared_ptr<Bitmask>(new Bitmask(task.capture_set()))));
        model -> identify(identifier);
        model -> translate_self(task.order());
        results.insert(model);
    };

    if (rashomon_flag) {
        if (task.maximum_scope > 0) {
            re_explore_count++;
        }
        if (task.maximum_scope < scope) {
            if (task.maximum_scope > 0) {
                re_explore_by_scope_update_count++;
            }
            task.maximum_scope = scope;
        }

        if (task.base_objective() <= scope + tolerance) {
            insert_stump();
        } else {
            pruned_leaves_with_scope++;
        }
    } else if (task.base_objective() <= task.upperbound() + tolerance) {
        insert_stump();
    }

    bound_accessor bounds;
    if (!State::graph.bounds.find(bounds, identifier)) { return; }
    for (bound_iterator iterator = bounds -> second.begin(); iterator != bounds -> second.end(); ++iterator) {
        // Skip splits whose bound cannot fit: lower bound vs scope in Rashomon mode,
        // upper bound vs the task's upper bound otherwise.
        if (rashomon_flag) { if (std::get<1>(* iterator) > scope + tolerance) { continue; } }
        else { if (std::get<2>(* iterator) > task.upperbound() + tolerance) { continue; } }
        int feature = std::get<0>(* iterator);

        std::unordered_set< std::shared_ptr<Model> > negatives;
        std::unordered_set< std::shared_ptr<Model> > positives;

        child_accessor left_key, right_key;
        vertex_accessor left_child, right_child;
        float left_lowerbound = 0, right_lowerbound = 0;

        // Negative side (child key -(feature + 1)): an explored child subproblem, or —
        // when no child key exists — a single leaf over subset(feature, false, ...).
        bool left_has_key = State::graph.children.find(left_key, std::make_pair(identifier, -(feature + 1)));
        bool left_has_child = left_has_key && State::graph.vertices.find(left_child, left_key->second);
        if (left_has_child) {
            left_lowerbound = left_child->second.lowerbound();
            left_child.release();
        } else if (!left_has_key) {
            Bitmask subset(task.capture_set());
            State::dataset.subset(feature, false, subset);
            std::shared_ptr<Model> model(new Model(std::shared_ptr<Bitmask>(new Bitmask(subset))));
            float leaf_objective = model->loss() + model->complexity();
            if (rashomon_flag && leaf_objective > scope + tolerance) {
                pruned_leaves_with_scope++;
                continue;
            }
            left_lowerbound = leaf_objective;
            negatives.insert(model);
        } else {
            // A child key exists but its vertex does not: nothing to combine.
            continue;
        }

        // Positive side (child key feature + 1), mirroring the negative side.
        bool right_has_key = State::graph.children.find(right_key, std::make_pair(identifier, feature + 1));
        bool right_has_child = right_has_key && State::graph.vertices.find(right_child, right_key->second);
        if (right_has_child) {
            right_lowerbound = right_child->second.lowerbound();
            right_child.release();
        } else if (!right_has_key) {
            Bitmask subset(task.capture_set());
            State::dataset.subset(feature, true, subset);
            std::shared_ptr<Model> model(new Model(std::shared_ptr<Bitmask>(new Bitmask(subset))));
            float leaf_objective = model->loss() + model->complexity();
            if (rashomon_flag && leaf_objective > scope + tolerance) {
                pruned_leaves_with_scope++;
                continue;
            }
            right_lowerbound = leaf_objective;
            positives.insert(model);
        } else {
            // A child key exists but its vertex does not (possibly unreachable).
            continue;
        }

        if (rashomon_flag && (scope - right_lowerbound < 0 || scope - left_lowerbound < 0)) { continue; }

        // Each child is explored with the scope left over after its sibling's lower bound.
        if (left_has_child) {
            models(left_key -> second, negatives, scope - right_lowerbound);
            left_key.release();
        }
        if (negatives.empty()) { continue; }

        if (right_has_child) {
            models(right_key -> second, positives, scope - left_lowerbound);
            right_key.release();
        }
        if (positives.empty()) { continue; }

        if (Configuration::rule_list) {
            throw std::invalid_argument("Does not support rule lists");
        }

        // Cross product: one split model per (negative, positive) pair.
        for (auto const & negative_ptr : negatives) {
            for (auto const & positive_ptr : positives) {
                if (Configuration::model_limit > 0 && results.size() > Configuration::model_limit) {
                    model_limit_exceeded = true;
                    return;
                }

                std::shared_ptr<Model> negative_model(negative_ptr);
                std::shared_ptr<Model> positive_model(positive_ptr);

                if (rashomon_flag) {
                    // Prune trivial extensions: both sides are leaves with the same prediction.
                    if (Configuration::rashomon_ignore_trivial_extensions
                        && negative_model->terminal == positive_model->terminal
                        && negative_model->terminal
                        && negative_model->get_prediction() == positive_model->get_prediction()) {
                        pruned_trivial_extension++;
                        continue;
                    }
                    // Prune combinations whose objective exceeds the scope. The same
                    // tolerance as every other scope test is used, so rounding in this
                    // four-term sum does not drop models that sit exactly on the bound.
                    if (negative_model->loss() + negative_model->complexity()
                        + positive_model->loss() + positive_model->complexity() > scope + tolerance) {
                        pruned_combinations_with_scope++;
                        continue;
                    }
                }

                std::shared_ptr<Model> model(new Model(feature, negative_model, positive_model));
                model -> identify(identifier);
                model -> translate_self(task.order());
                translation_accessor negative_translation, positive_translation;
                if (negative_model->identified()
                    && State::graph.translations.find(negative_translation, std::make_pair(identifier, -(feature + 1)))) {
                    model -> translate_negatives(negative_translation -> second);
                }
                negative_translation.release();
                if (positive_model->identified()
                    && State::graph.translations.find(positive_translation, std::make_pair(identifier, feature + 1))) {
                    model -> translate_positives(positive_translation -> second);
                }
                positive_translation.release();

                results.insert(model);
            }
        }
    }

    max_result_size = std::max(max_result_size, results.size());
}


